open Ast

(* Immutable map with string keys.  Despite its name, string_of_type uses it
   to remember which display name (stored as an int) was assigned to each
   type-variable name. *)
module CharMap = Map.Make(String)

(* Functions to print AST *)
(* Surface-syntax symbol for a binary operator, as used by both the
   untyped and the annotated expression printers. *)
let string_of_op op =
  match op with
  | Add    -> "+"
  | Sub    -> "-"
  | Mult   -> "*"
  | Div    -> "/"
  | Mod    -> "%"
  | Caret  -> "^"
  | Eq     -> "=="
  | Neq    -> "!="
  | Lt     -> "<"
  | Lte    -> "<="
  | Gt     -> ">"
  | Gte    -> ">="
  | And    -> "&&"
  | Or     -> "||"
  | Append -> "@"
  | Cons   -> "::"
;;

(* Surface-syntax symbol for a unary operator. *)
let string_of_uop op =
  match op with
  | Neg -> "-"
  | Not -> "!"
;;

(* Render an untyped expression back to source-like text.
   Binary operators are printed without parentheses, so the output does not
   show how nested operators were grouped.  Blocks are printed as their
   expressions joined by newlines, followed by "\nend\n". *)
let rec string_of_expr = function
  | VoidLit             -> "VOID"
  | NoExpr              -> "\n"
  | NumLit(n)           -> string_of_float n
  | StringLit(s)        -> s
  | BoolLit(true)       -> "True"
  | BoolLit(false)      -> "False"
  | ListLit(l)          -> "[" ^ String.concat ", " (List.map string_of_expr l) ^ "]"
  | DictLit(l)          ->
    let pairs =
      List.map (fun (k, v) -> string_of_expr k ^ ": " ^ string_of_expr v) l
    in
    "{" ^ String.concat ", " pairs ^ "}"
  | FunLit(args, s)     -> "lambda(" ^ String.concat ", " args ^ "):\n" ^ string_of_expr s
  | Val(s)              -> s
  | Binop(e1, op, e2)   -> string_of_expr e1 ^ " " ^ string_of_op op ^ " " ^ string_of_expr e2
  | Unop(op, e)         -> string_of_uop op ^ " " ^ string_of_expr e
  | Assign(v, e)        -> v ^ " = " ^ string_of_expr e
  (* Any callee is rendered with the generic printer.  For the usual
     [Call (Val f, args)] case this prints exactly "f(args)", as before. *)
  | Call(f, args)       ->
    string_of_expr f ^ "(" ^ String.concat ", " (List.map string_of_expr args) ^ ")"
  | Element(s, e)       -> s ^ "[" ^ string_of_expr e ^ "]"
  | Block(es)           -> String.concat "\n" (List.map string_of_expr es) ^ "\nend\n"
  | If(esl, s)          ->
    (* Terminate each if/elif branch with exactly one '\n' so the next
       keyword always starts on its own line.  Block bodies already end in
       '\n', so no blank line is inserted after them. *)
    let terminated str =
      let n = String.length str in
      if n > 0 && str.[n - 1] = '\n' then str else str ^ "\n"
    in
    let else_part = "else:\n" ^ string_of_expr s in
    (match esl with
     (* No condition branches: only the else part can be printed
        (previously this crashed in List.hd). *)
     | [] -> else_part
     | (cond, body) :: elifs ->
       let if_part = "if " ^ string_of_expr cond ^ ":\n" ^ string_of_expr body in
       let branches = if_part :: List.map string_of_elif elifs in
       String.concat "" (List.map terminated branches) ^ else_part)

(* Render one (condition, body) pair as an "elif" branch. *)
and string_of_elif el =
    "elif " ^ string_of_expr (fst el) ^ ":\n" ^ string_of_expr (snd el)
;;

(* Render a whole program: each top-level expression is printed and the
   results are joined with newlines. *)
let string_of_program prog =
  prog |> List.map string_of_expr |> String.concat "\n"
;;

(* Render a type as a string.
   Type variables [T x] are renamed in order of first appearance (left to
   right; function arguments before the return type, dictionary keys before
   values).  The first 26 distinct variables are named A..Z; later ones
   reuse the letters with a numeric suffix (A1, B1, ..., A2, ...).  The
   previous scheme used raw character codes, which produced punctuation
   after 'Z' and raised Invalid_argument past code 255. *)
let string_of_type t =
  (* Display name for the [i]-th distinct type variable (0-based). *)
  let name_of_index i =
    let letter = Char.chr (Char.code 'A' + i mod 26) in
    if i < 26 then String.make 1 letter
    else Printf.sprintf "%c%d" letter (i / 26)
  in
  (* [helper t next map] returns the rendering of [t], the index to give
     the next fresh type variable, and the updated name -> index map. *)
  let rec helper t next map =
    match t with
      TNum    -> "num", next, map
    | TString -> "string", next, map
    | TBool   -> "bool", next, map
    | TAny    -> "any", next, map
    | TVoid   -> "void", next, map
    | T(x)    ->
      (match CharMap.find_opt x map with
       | Some i -> name_of_index i, next, map
       | None -> name_of_index next, next + 1, CharMap.add x next map)
    | TList(t) ->
      let st, next, map = helper t next map in
      "list " ^ st, next, map
    | TDict(kt, vt) ->
      let sk, next, map = helper kt next map in
      let sv, next, map = helper vt next map in
      Printf.sprintf "<%s:%s>" sk sv, next, map
    | TFun(args_type, ret_type) ->
      (* Accumulate argument strings in reverse to avoid quadratic appends. *)
      let rev_args, next, map =
        List.fold_left
          (fun (acc, next, map) arg ->
             let s, next, map = helper arg next map in
             (s :: acc, next, map))
          ([], next, map) args_type
      in
      let rs, next, map = helper ret_type next map in
      Printf.sprintf "(%s) -> %s" (String.concat ", " (List.rev rev_args)) rs, next, map
  in
  let s, _, _ = helper t 0 CharMap.empty in s
;;

(* Render a type-annotated expression, showing each node's type after a
   colon.  Dictionary literals and if-expressions are printed as fixed
   placeholders (their contents are not shown); void literals and blocks
   do not show their type.
   Raises [Failure "not a valid function"] if an [AFunLit] is not annotated
   with a [TFun] type, and [Invalid_argument] (from [List.combine]) if its
   parameter names and argument types differ in length. *)
let rec string_of_aexpr ae =
  match ae with
    ABoolLit(b, t)          -> Printf.sprintf "(%s: %s)" (string_of_bool b) (string_of_type t)
  | ANumLit(x, t)           -> Printf.sprintf "(%s: %s)" (string_of_float x) (string_of_type t)
  | AStringLit(s, t)        -> Printf.sprintf "(%s: %s)" s (string_of_type t)
  | AVoidLit(_)             -> "(VOID)"
  | AVal(s, t)              -> Printf.sprintf "(%s: %s)" s (string_of_type t)
  | ABinop(e1, op, e2, t)   -> Printf.sprintf "(%s %s %s: %s)" (string_of_aexpr e1) (string_of_op op) (string_of_aexpr e2) (string_of_type t)
  (* Operator before operand, matching the Unop case of string_of_expr
     (the operands used to be passed to sprintf in swapped order). *)
  | AUnop(op, e, t)         -> Printf.sprintf "(%s %s: %s)" (string_of_uop op) (string_of_aexpr e) (string_of_type t)
  | AAssign(id, t1, e, t2)  -> Printf.sprintf "%s: %s = %s : %s" id (string_of_type t1) (string_of_aexpr e) (string_of_type t2)
  | AListLit(aes, t)        -> Printf.sprintf "[%s]:%s" (String.concat "," (List.map string_of_aexpr aes)) (string_of_type t)
  (* TODO: render the key/value pairs instead of a placeholder. *)
  | ADictLit(_, t)          -> Printf.sprintf "dict < >: %s" (string_of_type t)
  (* TODO: render the branches instead of a placeholder. *)
  | AIf(_, _, t)            -> Printf.sprintf "if {} elif {} else {}: %s" (string_of_type t)
  | ABlock(aes, _)          -> Printf.sprintf "{ %s }" (String.concat "\n" (List.map string_of_aexpr aes))
  | ACall(afn, aargs, t)    -> Printf.sprintf "%s(%s) : %s" (string_of_aexpr afn) (String.concat "," (List.map string_of_aexpr aargs)) (string_of_type t)
  | AElement(id, ae, t)     -> Printf.sprintf "%s[%s] : %s" id (string_of_aexpr ae) (string_of_type t)
  | AFunLit(ids, body, t)   ->
    (* Pair each parameter name with its argument type from the TFun annotation. *)
    let args_with_types, ret_type =
      match t with
      | TFun(args_type, ret_type) -> List.combine ids args_type, ret_type
      | _ -> failwith "not a valid function"
    in
    let fargs =
      String.concat ", "
        (List.map (fun (id, typ) -> id ^ " : " ^ string_of_type typ) args_with_types)
    in
    let fsig = "(" ^ fargs ^ ")" ^ " : " ^ string_of_type ret_type in
    let fbody = string_of_aexpr body in
    String.concat " " ["lambda"; fsig; "="; "{"; fbody; "}"]
;;

(* Render a whole annotated program: each top-level annotated expression
   is printed and the results are joined with newlines. *)
let string_of_aprogram aprog =
  aprog |> List.map string_of_aexpr |> String.concat "\n"
;;